{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0",
   "metadata": {
    "id": "90a71fe6"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.tree import DecisionTreeRegressor\n",
    "from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor\n",
    "from sklearn.neural_network import MLPRegressor"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {
    "id": "elegant-proxy",
    "papermill": {
     "duration": 0.011489,
     "end_time": "2021-03-30T21:54:42.895419",
     "exception": false,
     "start_time": "2021-03-30T21:54:42.883930",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "# Function Approximations by Trees and Neural Networks\n",
    "\n",
    "Here we show how the function\n",
    "$$\n",
    "x \\mapsto \\exp(4x)\n",
    "$$\n",
    "can be easily approximated by tree-based methods (a single tree, a random forest, boosted trees) and by a neural network (with two hidden layers)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {
    "id": "376635dd"
   },
   "outputs": [],
   "source": [
    "# noiseless data\n",
    "def gen_data(n, rng=None):\n",
    "    \"\"\"Draw n noiseless samples of y = exp(4 x) with x uniform on [0, 1).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n : int\n",
    "        Number of samples to draw.\n",
    "    rng : numpy.random.Generator or numpy.random.RandomState, optional\n",
    "        Source of randomness. When omitted, NumPy's global random state is\n",
    "        used, so a preceding np.random.seed(...) call still controls the draw.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    X : numpy.ndarray of shape (n, 1)\n",
    "        Inputs as a single-column feature matrix.\n",
    "    y : numpy.ndarray of shape (n,)\n",
    "        Targets exp(4 * X[:, 0]).\n",
    "    \"\"\"\n",
    "    # Fall back to the legacy global state so existing seeded runs are unchanged.\n",
    "    sampler = np.random if rng is None else rng\n",
    "    X = sampler.uniform(0, 1, size=(n, 1))\n",
    "    y = np.exp(4 * X[:, 0])\n",
    "    return X, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {
    "id": "d608d3dd"
   },
   "outputs": [],
   "source": [
    "# Fix NumPy's global seed, then draw the 1000-point sample used throughout.\n",
    "np.random.seed(seed=123)\n",
    "X, y = gen_data(n=1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4",
   "metadata": {
    "id": "widespread-mention",
    "papermill": {
     "duration": 0.009467,
     "end_time": "2021-03-30T21:54:42.915858",
     "exception": false,
     "start_time": "2021-03-30T21:54:42.906391",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "# Function Approximation by a Tree\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {
    "id": "c39fd891"
   },
   "outputs": [],
   "source": [
    "# Regression tree with cost-complexity pruning (ccp_alpha=0.2), evaluated on the training inputs.\n",
    "tr = DecisionTreeRegressor(ccp_alpha=0.2)\n",
    "tr.fit(X, y)\n",
    "ypred = tr.predict(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {
    "id": "1d16f253"
   },
   "outputs": [],
   "source": [
    "# True values vs. in-sample predictions, labelled so the figure stands alone.\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(X[:, 0], y, label='true')\n",
    "ax.scatter(X[:, 0], ypred, label='pred')\n",
    "ax.set(xlabel='x', ylabel='y', title='Decision tree (ccp_alpha=0.2)')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {
    "id": "0ecd8e8f"
   },
   "outputs": [],
   "source": [
    "# Same tree with a much smaller pruning penalty (ccp_alpha=0.001), evaluated on the training inputs.\n",
    "tr = DecisionTreeRegressor(ccp_alpha=0.001)\n",
    "tr.fit(X, y)\n",
    "ypred = tr.predict(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {
    "id": "d725ab29"
   },
   "outputs": [],
   "source": [
    "# True values vs. in-sample predictions, labelled so the figure stands alone.\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(X[:, 0], y, label='true')\n",
    "ax.scatter(X[:, 0], ypred, label='pred')\n",
    "ax.set(xlabel='x', ylabel='y', title='Decision tree (ccp_alpha=0.001)')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9",
   "metadata": {
    "id": "local-saturn",
    "papermill": {
     "duration": 0.013444,
     "end_time": "2021-03-30T21:54:43.953303",
     "exception": false,
     "start_time": "2021-03-30T21:54:43.939859",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "# Function Approximation by a Random Forest"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10",
   "metadata": {
    "id": "international-serum",
    "papermill": {
     "duration": 0.01351,
     "end_time": "2021-03-30T21:54:43.980273",
     "exception": false,
     "start_time": "2021-03-30T21:54:43.966763",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "Here we show how the function\n",
    "$$\n",
    "x \\mapsto \\exp(4x)\n",
    "$$\n",
    "can be easily approximated by a tree-based method (a random forest) and by a neural network (with two hidden layers)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {
    "id": "2c89a0c5"
   },
   "outputs": [],
   "source": [
    "# Random forest with default hyperparameters, evaluated on the training inputs.\n",
    "rf = RandomForestRegressor()\n",
    "rf.fit(X, y)\n",
    "ypred = rf.predict(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12",
   "metadata": {
    "id": "fcfc51f2"
   },
   "outputs": [],
   "source": [
    "# True values vs. in-sample predictions, labelled so the figure stands alone.\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(X[:, 0], y, label='true')\n",
    "ax.scatter(X[:, 0], ypred, label='pred')\n",
    "ax.set(xlabel='x', ylabel='y', title='Random forest')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13",
   "metadata": {
    "id": "infrared-belgium",
    "papermill": {
     "duration": 0.015474,
     "end_time": "2021-03-30T21:54:45.201078",
     "exception": false,
     "start_time": "2021-03-30T21:54:45.185604",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "# Boosted Trees"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14",
   "metadata": {
    "id": "f5118cea"
   },
   "outputs": [],
   "source": [
    "# Gradient boosting with 100 trees and learning_rate=0.01, evaluated on the training inputs.\n",
    "gbf = GradientBoostingRegressor(learning_rate=0.01, n_estimators=100)\n",
    "gbf.fit(X, y)\n",
    "ypred = gbf.predict(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15",
   "metadata": {
    "id": "6c4f42fc"
   },
   "outputs": [],
   "source": [
    "# True values vs. in-sample predictions, labelled so the figure stands alone.\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(X[:, 0], y, label='true')\n",
    "ax.scatter(X[:, 0], ypred, label='pred')\n",
    "ax.set(xlabel='x', ylabel='y', title='Gradient boosting (100 trees, learning_rate=0.01)')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16",
   "metadata": {
    "id": "137e44e0"
   },
   "outputs": [],
   "source": [
    "# Same booster with ten times as many trees (1000) at learning_rate=0.01.\n",
    "gbf = GradientBoostingRegressor(learning_rate=0.01, n_estimators=1000)\n",
    "gbf.fit(X, y)\n",
    "ypred = gbf.predict(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17",
   "metadata": {
    "id": "1649101f"
   },
   "outputs": [],
   "source": [
    "# True values vs. in-sample predictions, labelled so the figure stands alone.\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(X[:, 0], y, label='true')\n",
    "ax.scatter(X[:, 0], ypred, label='pred')\n",
    "ax.set(xlabel='x', ylabel='y', title='Gradient boosting (1000 trees, learning_rate=0.01)')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18",
   "metadata": {
    "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
    "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
    "id": "psychological-venice",
    "papermill": {
     "duration": 0.018291,
     "end_time": "2021-03-30T21:54:47.087924",
     "exception": false,
     "start_time": "2021-03-30T21:54:47.069633",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "# Same Example with a Neural Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19",
   "metadata": {
    "id": "9c061994"
   },
   "outputs": [],
   "source": [
    "# Two hidden layers (200 and 20 units) with ReLU activations; training is capped at max_iter=1.\n",
    "nnet = MLPRegressor(hidden_layer_sizes=(200, 20),\n",
    "                    activation='relu',\n",
    "                    learning_rate_init=0.01,\n",
    "                    batch_size=10,\n",
    "                    max_iter=1)\n",
    "nnet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20",
   "metadata": {
    "id": "2a1d7cc0"
   },
   "outputs": [],
   "source": [
    "# fit() returns the estimator itself, so fitting and in-sample prediction can be chained.\n",
    "ypred = nnet.fit(X, y).predict(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21",
   "metadata": {
    "id": "acf2ce5b"
   },
   "outputs": [],
   "source": [
    "# True values vs. in-sample predictions, labelled so the figure stands alone.\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(X[:, 0], y, label='true')\n",
    "ax.scatter(X[:, 0], ypred, label='pred')\n",
    "ax.set(xlabel='x', ylabel='y', title='MLPRegressor (max_iter=1)')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22",
   "metadata": {
    "id": "a9db1d51"
   },
   "outputs": [],
   "source": [
    "# Raise the iteration cap on the existing estimator, then refit and predict in-sample.\n",
    "nnet.set_params(max_iter=100)\n",
    "ypred = nnet.fit(X, y).predict(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23",
   "metadata": {
    "id": "539ebb01"
   },
   "outputs": [],
   "source": [
    "# True values vs. in-sample predictions, labelled so the figure stands alone.\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(X[:, 0], y, label='true')\n",
    "ax.scatter(X[:, 0], ypred, label='pred')\n",
    "ax.set(xlabel='x', ylabel='y', title='MLPRegressor (max_iter=100)')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24",
   "metadata": {
    "id": "f261e981"
   },
   "source": [
    "### Using the PyTorch Neural Network Library and Its Scikit-learn-Compatible Wrapper, skorch\n",
    "\n",
    "We first need to install skorch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25",
   "metadata": {
    "id": "85047mypXqrD"
   },
   "outputs": [],
   "source": [
    "# Use the %pip magic rather than !pip so the package is installed into the running kernel's environment.\n",
    "%pip install skorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26",
   "metadata": {
    "id": "49e3b46e"
   },
   "outputs": [],
   "source": [
    "import skorch\n",
    "from skorch import NeuralNetRegressor\n",
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27",
   "metadata": {
    "id": "7defdcf2"
   },
   "outputs": [],
   "source": [
    "# Seed PyTorch's generator: np.random.seed does not control torch, so without this\n",
    "# the layers' random initial weights differ on every run.\n",
    "torch.manual_seed(123)\n",
    "\n",
    "# Same layout as the MLPRegressor above: hidden layers of 200 and 20 units with ReLU, one output.\n",
    "arch = nn.Sequential(nn.Linear(X.shape[1], 200), nn.ReLU(),\n",
    "                     nn.Linear(200, 20), nn.ReLU(),\n",
    "                     nn.Linear(20, 1))\n",
    "# train_split=None: no validation split is configured at this point.\n",
    "nnet = NeuralNetRegressor(arch, lr=0.01, batch_size=10, max_epochs=1,\n",
    "                          optimizer=torch.optim.Adam, train_split=None)\n",
    "nnet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28",
   "metadata": {
    "id": "0f25a6ff"
   },
   "outputs": [],
   "source": [
    "# Cast once to float32 and make the target a single column before handing the data to the network.\n",
    "X_f32 = X.astype(np.float32)\n",
    "y_f32 = y.reshape(-1, 1).astype(np.float32)\n",
    "nnet.fit(X_f32, y_f32)\n",
    "ypred = nnet.predict(X_f32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29",
   "metadata": {
    "id": "d25c8bfe"
   },
   "outputs": [],
   "source": [
    "# True values vs. in-sample predictions, labelled so the figure stands alone.\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(X[:, 0], y, label='true')\n",
    "ax.scatter(X[:, 0], ypred, label='pred')\n",
    "ax.set(xlabel='x', ylabel='y', title='skorch network (max_epochs=1)')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30",
   "metadata": {
    "id": "cf4b2ce6"
   },
   "outputs": [],
   "source": [
    "# training for more: raise the epoch budget, then fit again on the float32 data\n",
    "nnet.max_epochs = 100\n",
    "X_f32 = X.astype(np.float32)\n",
    "y_f32 = y.reshape(-1, 1).astype(np.float32)\n",
    "nnet.fit(X_f32, y_f32)\n",
    "ypred = nnet.predict(X_f32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31",
   "metadata": {
    "id": "74d5b869"
   },
   "outputs": [],
   "source": [
    "# True values vs. in-sample predictions, labelled so the figure stands alone.\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(X[:, 0], y, label='true')\n",
    "ax.scatter(X[:, 0], ypred, label='pred')\n",
    "ax.set(xlabel='x', ylabel='y', title='skorch network (max_epochs=100)')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32",
   "metadata": {
    "id": "ff746610"
   },
   "outputs": [],
   "source": [
    "# adding early stopping based on validation set performance\n",
    "early_stopping = skorch.callbacks.EarlyStopping()  # early stopping callback\n",
    "holdout_split = skorch.dataset.ValidSplit(5)  # 20% validation\n",
    "nnet.callbacks = [early_stopping]\n",
    "nnet.train_split = holdout_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33",
   "metadata": {
    "id": "510bfdb9"
   },
   "outputs": [],
   "source": [
    "# Fit again with the current train_split / callbacks settings, then predict on all inputs.\n",
    "X_f32 = X.astype(np.float32)\n",
    "y_f32 = y.reshape(-1, 1).astype(np.float32)\n",
    "nnet.fit(X_f32, y_f32)\n",
    "ypred = nnet.predict(X_f32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34",
   "metadata": {
    "id": "8678cba2"
   },
   "outputs": [],
   "source": [
    "# True values vs. in-sample predictions, labelled so the figure stands alone.\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(X[:, 0], y, label='true')\n",
    "ax.scatter(X[:, 0], ypred, label='pred')\n",
    "ax.set(xlabel='x', ylabel='y', title='skorch network with early stopping')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35",
   "metadata": {
    "id": "1907bf04"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 30.682213,
   "end_time": "2021-03-30T21:55:10.531019",
   "environment_variables": {},
   "exception": null,
   "input_path": "__notebook__.ipynb",
   "output_path": "__notebook__.ipynb",
   "parameters": {},
   "start_time": "2021-03-30T21:54:39.848806",
   "version": "2.3.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
